# ======================================================================
# Copyright CERFACS (March 2019)
# Contributor: Adrien Suau (adrien.suau@cerfacs.fr)
#
# This software is governed by the CeCILL-B license under French law and
# abiding  by the  rules of  distribution of free software. You can use,
# modify  and/or  redistribute  the  software  under  the  terms  of the
# CeCILL-B license as circulated by CEA, CNRS and INRIA at the following
# URL "http://www.cecill.info".
#
# As a counterpart to the access to  the source code and rights to copy,
# modify and  redistribute granted  by the  license, users  are provided
# only with a limited warranty and  the software's author, the holder of
# the economic rights,  and the  successive licensors  have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading,  using, modifying and/or  developing or reproducing  the
# software by the user in light of its specific status of free software,
# that  may mean  that it  is complicated  to manipulate,  and that also
# therefore  means that  it is reserved for  developers and  experienced
# professionals having in-depth  computer knowledge. Users are therefore
# encouraged  to load and  test  the software's  suitability as  regards
# their  requirements  in  conditions  enabling  the  security  of their
# systems  and/or  data to be  ensured and,  more generally,  to use and
# operate it in the same conditions as regards security.
#
# The fact that you  are presently reading this  means that you have had
# knowledge of the CeCILL-B license and that you accept its terms.
# ======================================================================

"""Functions dealing with the linear algebra formalism of quantum computing.

All the functions in this module are here to ease the manipulation of
  1. Quantum states as vectors of complex numbers.
  2. Quantum gates as matrices of complex numbers.
"""

import numpy

from qaths.utils import constants
from qaths.utils.math import kron

# Computational basis vectors |0> and |1>, stored as complex vectors.
# ``numpy.complex`` was only a deprecated alias of the builtin ``complex`` (that is
# ``numpy.complex128``) and was removed in NumPy 1.24, so the explicit dtype is used.
zero = numpy.array([1.0, 0.0], dtype=numpy.complex128)
one = numpy.array([0.0, 1.0], dtype=numpy.complex128)


def qstate_msb(state: str) -> numpy.ndarray:
    """Compute the vector representing the given state.

    :param state: A string representing a quantum state of the computational basis,
        made only of the characters ``"0"`` and ``"1"``. The string will be read in
        MSB format (i.e. the first character will be the **most** significant bit).
    :return: A vector representing the given state in the **MSB** format (according to
        qaths convention of expressing everything in MSB).
    :raises ValueError: if ``state`` contains a character other than ``"0"`` or
        ``"1"``.
    """
    # Without this check, any character other than "0" would be read as |1>, so a
    # typo such as "0a1" or "|01>" would silently produce a wrong state.
    invalid_characters = set(state) - {"0", "1"}
    if invalid_characters:
        raise ValueError(
            f"Invalid character(s) {sorted(invalid_characters)} in quantum state "
            f"'{state}': only '0' and '1' are allowed."
        )
    return kron(*[zero if c == "0" else one for c in state])


def qstate_lsb(state: str) -> numpy.ndarray:
    """Build the state vector of a bit string written least-significant bit first.

    :param state: A string representing a quantum state. The string is read in
        LSB format (i.e. the first character is the **least** significant bit).
    :return: The vector representing the given state, laid out in the **MSB**
        format (qaths expresses everything in MSB).
    """
    # Reversing the characters turns the LSB-first string into an MSB-first one,
    # which is the order qstate_msb expects.
    msb_ordered_state = "".join(reversed(state))
    return qstate_msb(msb_ordered_state)


def complete_by_zeros(
    state: numpy.ndarray, number_of_zeros: int, front: bool = False
) -> numpy.ndarray:
    r"""Enlarge a quantum state with extra qubits prepared in :math:`\ket{0}`.

    :param state: Complex vector describing the quantum state to enlarge.
    :param number_of_zeros: How many :math:`\ket{0}` qubits to append.
    :param front: When True, the new qubits are placed before ``state``; by default
        (False) they are placed after it.
    :return: The tensor product of ``state`` with the requested number of
        :math:`\ket{0}` qubits, in the requested order.
    """
    padding = [zero for _ in range(number_of_zeros)]
    if not front:
        return kron(state, *padding)
    return kron(*padding, state)


def qstate_fidelity(qstate1: numpy.ndarray, qstate2: numpy.ndarray) -> float:
    r"""Return the fidelity :math:`|\langle q_1 | q_2 \rangle|^2` of two quantum states.

    This measure is well suited to quantum states because it ignores any global
    phase difference between them.

    :param qstate1: First state to compare (conjugated in the inner product).
    :param qstate2: Second state to compare.
    :return: For normalised states, a real number in :math:`[0, 1]`: :math:`1` means
        the two states are the same up to a global phase, and :math:`0` means they
        are orthogonal.
    """
    overlap = numpy.vdot(qstate1, qstate2)
    overlap_magnitude = numpy.abs(overlap)
    return numpy.square(overlap_magnitude)


def assert_transformation_allclose(
    transformation1: numpy.ndarray,
    transformation2: numpy.ndarray,
    atol: float = constants.ABSOLUTE_TOLERANCE,
    nnz_atol: float = constants.UNDER_IS_ZERO,
) -> None:
    """Assert that the two given transformations are equal up to a global phase.

    :param transformation1: The first transformation to compare.
    :param transformation2: The second transformation to compare.
    :param atol: The absolute tolerance of the comparison.
    :param nnz_atol: A limit under which all the values are considered as zero.
    :raises AssertionError: if the non-zero patterns of the transformations differ,
        or if they are not equal up to a multiplicative factor of modulus 1.
    """
    # First perform a plain element-wise comparison. We don't assert on this one
    # because transformations that differ only by a global phase are still
    # considered equal while failing this check.
    if numpy.allclose(transformation1, transformation2, atol=atol):
        # Equal entry by entry, so equal up to a global phase as well.
        return

    # The only option left is that the two transformations differ by a global phase.
    # Only the non-zero entries are used to compute it, to avoid dividing by zero.
    nnz1 = numpy.abs(transformation1) > numpy.abs(nnz_atol)
    nnz2 = numpy.abs(transformation2) > numpy.abs(nnz_atol)
    # If the non-zero positions don't match, we don't need to continue.
    numpy.testing.assert_equal(nnz1, nnz2)

    nonzeros2 = transformation2[nnz2]
    factors = transformation1[nnz1] / nonzeros2
    if factors.size == 0:
        # Every entry of both transformations is considered zero: they are equal.
        # (Indexing ``factors[0]`` here would raise an IndexError.)
        return

    # Take the ratio of the largest-magnitude entry as the reference phase: dividing
    # the biggest values is the least sensitive to rounding errors.
    phase = factors[numpy.argmax(numpy.abs(nonzeros2))]
    # A global phase has modulus 1. Any other common factor is a rescaling and does
    # not make the two transformations equal.
    numpy.testing.assert_allclose(numpy.abs(phase), 1.0, atol=atol)
    # Finally check that every non-zero entry differs by this same phase.
    numpy.testing.assert_allclose(
        factors, numpy.full(factors.shape, phase), atol=atol
    )


def assert_qstate_allclose(
    qstate1: numpy.ndarray,
    qstate2: numpy.ndarray,
    atol: float = constants.ABSOLUTE_TOLERANCE,
) -> None:
    """Assert that the given quantum states are close, up to a global phase.

    The fidelity is computed on the normalised states. An unnormalised input
    therefore cannot make the fidelity exceed 1 and pass spuriously: for example,
    [1, 1] against [1, 0] has a raw overlap of 1 but a normalised fidelity of 0.5.

    :param qstate1: First quantum state.
    :param qstate2: Second quantum state.
    :param atol: Absolute tolerance on ``1 - fidelity``.
    :raises AssertionError: if either state is the null vector or if the states are
        not close. An explicit raise is used instead of ``assert`` so that the
        check still runs under ``python -O``.
    """
    norm1 = numpy.linalg.norm(qstate1)
    norm2 = numpy.linalg.norm(qstate2)
    if norm1 == 0 or norm2 == 0:
        raise AssertionError("A null vector is not a valid quantum state.")
    fidelity = qstate_fidelity(qstate1, qstate2) / (norm1 * norm2) ** 2
    # Written as "not (... < atol)" so that a NaN fidelity also fails the check.
    if not (1.0 - fidelity) < atol:
        raise AssertionError(
            f"Quantum states are not close: fidelity {fidelity} is not within "
            f"{atol} of 1."
        )
